import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

import models.bottom_model_plus as models
import copy
import dill
from vfl_framework import VflFramework
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np


def create_model(ema=False, size_bottom_out=10, num_classes=10):
    """Build a BottomModelPlus on the GPU.

    When `ema` is True the parameters are detached so the optimizer never
    updates them (EMA weights are maintained manually elsewhere).
    """
    net = models.BottomModelPlus(size_bottom_out, num_classes).cuda()
    if ema:
        for p in net.parameters():
            p.detach_()
    return net


def correct_counter(output, target):
    """Return the number of rows in `output` whose top-1 class equals `target`.

    `output` is a (N, C) score/logit tensor, `target` a length-N label tensor.
    The count is returned as a Python float.
    """
    _, top1 = output.topk(1, dim=1, largest=True, sorted=True)
    hits = top1.eq(target.view(-1, 1))
    return hits.sum().float().item()


def train_and_validata_model(train_loader, test_loader, model, epochs):
    """Train `model` with SGD + cross-entropy, then report final-epoch metrics.

    Args:
        train_loader: DataLoader yielding (features, integer class labels).
        test_loader: DataLoader over the held-out split.
        model: nn.Module mapping float features to class logits.
        epochs: number of full passes over the training set.

    After the last epoch, prints the mean training loss and the accuracy on
    the test split. Returns None; `model` is trained in place.
    """
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=2e-3, momentum=0.9)
    for epoch in range(epochs):

        # train
        model.train()
        loss_epoch = 0.0
        count_batch = 0
        for data, labels in train_loader:
            data = data.float()
            output = model(data)
            loss = loss_func(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches the scalar; accumulating the tensor itself
            # would keep every batch's autograd graph alive all epoch.
            loss_epoch += loss.item()
            count_batch += 1
        # validate — inference only, so skip gradient bookkeeping
        model.eval()
        count_correct = 0
        with torch.no_grad():
            for data, labels in test_loader:
                data = data.float()
                output = model(data)
                count_correct += correct_counter(output, labels)
        if epoch == epochs - 1:
            print(f'epoch {epoch}')
            print(f'train loss:{loss_epoch / count_batch}')
            # Divide by the size of the split actually evaluated. The previous
            # len(test_loader.dataset.df) was the FULL dataframe, which
            # understated accuracy whenever the test split was partial.
            print(f'accuracy on test set:{count_correct / len(test_loader.dataset)}')


class MyModel(nn.Module):

    def __init__(self, in_dim=14):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(in_dim, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x


class MyDataset(data.Dataset):
    """Dataset over a DataFrame split into a small 'known' train part and a test part.

    The first `fraction_of_known` of the rows form the training split, the
    remainder the test split. When `use_diagnosis_as_feature` is False the
    'diagnosis' column is dropped and the first 14 columns are used as
    features; otherwise the first 15 columns (including 'diagnosis') are used.
    Labels come from `label_area_mean`, cast to integer type.
    """

    def __init__(self, df, label_area_mean, use_diagnosis_as_feature=False, train=True, fraction_of_known=0.1):
        split = int(fraction_of_known * len(df))
        labels = np.array(label_area_mean, dtype=np.longlong)
        if use_diagnosis_as_feature:
            n_feat = 15
        else:
            n_feat = 14
            df = df.drop('diagnosis', axis=1)
        self.df = df
        self.train_data = np.array(df.iloc[:split, :n_feat])
        self.train_labels = labels[:split]
        self.test_data = np.array(df.iloc[split:, :n_feat])
        self.test_labels = labels[split:]
        self.train = train

    def __len__(self):
        return len(self.train_data) if self.train else len(self.test_data)

    def __getitem__(self, index):
        if self.train:
            return self.train_data[index], self.train_labels[index]
        return self.test_data[index], self.test_labels[index]


# --- Rebuild the attacker's inference model from saved checkpoints ----------
model = create_model(ema=False, size_bottom_out=2, num_classes=2)
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# produced by this project's own training runs.
checkpoint = torch.load('D:/MyCodes/label_inference_attacks_against_vfl/saved_experiment_results/saved_models/BCW_saved_models/BCW_saved_framework_lr=0.01_normal_half=14.pth', pickle_module=dill)
model.bottom_model = copy.deepcopy(checkpoint.malicious_bottom_model_a)
path = 'D:/MyCodes/label_inference_attacks_against_vfl/saved_experiment_results/saved_models/BCW_saved_models/BCW_mc_best.pth'
load_dict = torch.load(path)
model.load_state_dict(load_dict['state_dict'])
model = model.float()

# --- Load and preprocess the Breast Cancer Wisconsin CSV --------------------
path = 'D:/Datasets/BreastCancerWisconsin/wisconsin.csv'
df = pd.read_csv(path)
df = df.drop('Unnamed: 32', axis=1)
df = df.drop('id', axis=1)
# sequence adjustment: move radius_mean / perimeter_mean to the end —
# presumably so the first 14 feature columns match the bottom-model half of
# the vertical split (half=14 in the checkpoint name); TODO confirm.
radius_mean = df['radius_mean']
df = df.drop('radius_mean', axis=1)
df['radius_mean'] = radius_mean
perimeter_mean = df['perimeter_mean']
df = df.drop('perimeter_mean', axis=1)
df['perimeter_mean'] = perimeter_mean

# Standardize every column except column 0 ('diagnosis', the label).
sc = StandardScaler()
df[df.columns[1:]] = sc.fit_transform(df[df.columns[1:]])

# Use 'area_mean', binarized into 2 quantile bins, as the downstream label.
feature_area_mean = df['area_mean']
df = df.drop('area_mean', axis=1)
feature_area_mean = pd.qcut(feature_area_mean, q=2, labels=range(2))

batch_size = 16
# cover the label column 'diagnosis': overwrite the true labels with the
# score the attack model infers from the attacker's 14 features.
x = np.array(df.iloc[:, 1:1 + 14])
x = torch.tensor(x)
x = x.float().cuda()
y_predict = model(x)
# dim=1 made explicit: softmax over the class axis (the implicit-dim form is
# deprecated; for this 2-D input the behavior is unchanged).
y_score = torch.nn.functional.softmax(y_predict, dim=1)[:, :1]
y_score = y_score.reshape(-1).cpu().detach().numpy()
df['diagnosis'] = y_score
# NOTE(review): this re-fits `sc` on the single score column, discarding the
# statistics fitted above — harmless here since `sc` is not reused afterwards.
df[['diagnosis']] = sc.fit_transform(df[['diagnosis']])


# --- Baseline: eval WITHOUT the inferred labels as an extra feature ---------
print('eval on dataset WITHOUT the inferred labels as extra feature')
train_set_without_diagnosis = MyDataset(df=df, label_area_mean=feature_area_mean, use_diagnosis_as_feature=False, train=True)
train_dataloader_without_diagnosis = data.DataLoader(train_set_without_diagnosis, batch_size=batch_size, shuffle=True)
test_set_without_diagnosis = MyDataset(df=df, label_area_mean=feature_area_mean, use_diagnosis_as_feature=False, train=False)
test_dataloader_without_diagnosis = data.DataLoader(test_set_without_diagnosis, batch_size=batch_size, shuffle=True)

train_and_validata_model(train_loader=train_dataloader_without_diagnosis,
                         test_loader=test_dataloader_without_diagnosis,
                         model=MyModel(in_dim=14),
                         epochs=100)

print('\n\n')

# --- Comparison: eval WITH the inferred labels as an extra feature ----------
print('eval on dataset WITH the inferred labels as extra feature')
train_set_with_diagnosis = MyDataset(df=df, label_area_mean=feature_area_mean, use_diagnosis_as_feature=True, train=True)
train_dataloader_with_diagnosis = data.DataLoader(train_set_with_diagnosis, batch_size=batch_size, shuffle=True)
test_set_with_diagnosis = MyDataset(df=df, label_area_mean=feature_area_mean, use_diagnosis_as_feature=True, train=False)
test_dataloader_with_diagnosis = data.DataLoader(test_set_with_diagnosis, batch_size=batch_size, shuffle=True)

train_and_validata_model(train_loader=train_dataloader_with_diagnosis,
                         test_loader=test_dataloader_with_diagnosis,
                         model=MyModel(in_dim=15),
                         epochs=100)
